{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Encoder/decoder transformer, like in the paper\n",
    "\n",
    "# Single Encoder block\n",
    "\n",
    "# Embed input into a length 512 vector\n",
    "# Add positional encoding to the input\n",
    "# Alternating sin and cosine functions across the embedding dimension (512)\n",
    "# i = position in the embedding dimension\n",
    "# PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))\n",
    "# PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))\n",
    "# Sum positional encoding and input\n",
    "# apply dropout with chance .1\n",
    "\n",
    "# Self-attention\n",
    "# Each token in the input sequence is a query\n",
    "# Every other token is a key\n",
    "# Project query and key into a new dimensional space with a linear transform (this is multi-head attention)\n",
    "# We do the following 16 times for multi-head attention\n",
    "# Dot the projected query with the projected key matrix\n",
    "# QK^T\n",
    "# Attention(Q, K, V) = softmax(QK^T / √dk) V\n",
    "# √dk is the square root of the key dimension (dk)\n",
    "# Keys and values are the same in this particular case\n",
    "# Basically, compute an attention score (scalar) between each query and each key, then scale the value by that number\n",
    "# So less relevant other tokens are minimized\n",
    "# Concat the results of the attention equation\n",
    "# Run through another linear layer to reproject\n",
    "# Each attention head outputs 1/16th of the input embedding len\n",
    "\n",
    "# After doing multi-head attention (16)\n",
    "# Apply dropout with chance .1\n",
    "# Add the input to the layer (original embedded sequences) and the output of attention\n",
    "# Run layer normalization (unclear which layer norm to use)\n",
    "\n",
    "# Run a feed forward network\n",
    "# Add input to the layer to the output of the ff network\n",
    "# Normalize again\n",
    "\n",
    "# Single decoder block\n",
    "\n",
    "# Shift outputs right, to start with start token\n",
    "# Do positional encoding\n",
    "# Do dropout with chance .1\n",
    "# Mask outputs, so queries can only see keys that came before the query\n",
    "# Run multi-head attention\n",
    "# Add and norm\n",
    "\n",
    "# Run multi-head attention again, but this time v,k is from encoder stack, and q is from decoder stack\n",
    "# Apply dropout with chance .1\n",
    "# When doing add and norm, add in the decoder stack input\n",
    "\n",
    "# Feed forward\n",
    "\n",
    "# At top of stack, do another linear layer and softmax\n",
    "\n",
    "# Might want to move layer norm inside the residual block - https://arxiv.org/pdf/2002.04745.pdf\n",
    "# Layer normalization - https://arxiv.org/pdf/1607.06450.pdf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/vik/.virtualenvs/nnets/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "import functorch\n",
    "import sys\n",
    "import os\n",
    "import math\n",
    "sys.path.append(os.path.abspath(\"../../data\"))\n",
    "\n",
    "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "BATCH_SIZE = 1\n",
    "SP_VOCAB_SIZE = 1000\n",
    "TRAIN_SIZE = 500"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found cached dataset cnn_dailymail (/Users/vik/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
      "100%|██████████| 3/3 [00:00<00:00, 45.27it/s]\n",
      "sentencepiece_trainer.cc(177) LOG(INFO) Running command: --input=tokens.txt --model_prefix=cnn_dailymail --vocab_size=1000 --model_type=unigram\n",
      "sentencepiece_trainer.cc(77) LOG(INFO) Starts training with : \n",
      "trainer_spec {\n",
      "  input: tokens.txt\n",
      "  input_format: \n",
      "  model_prefix: cnn_dailymail\n",
      "  model_type: UNIGRAM\n",
      "  vocab_size: 1000\n",
      "  self_test_sample_size: 0\n",
      "  character_coverage: 0.9995\n",
      "  input_sentence_size: 0\n",
      "  shuffle_input_sentence: 1\n",
      "  seed_sentencepiece_size: 1000000\n",
      "  shrinking_factor: 0.75\n",
      "  max_sentence_length: 4192\n",
      "  num_threads: 16\n",
      "  num_sub_iterations: 2\n",
      "  max_sentencepiece_length: 16\n",
      "  split_by_unicode_script: 1\n",
      "  split_by_number: 1\n",
      "  split_by_whitespace: 1\n",
      "  split_digits: 0\n",
      "  treat_whitespace_as_suffix: 0\n",
      "  required_chars: \n",
      "  byte_fallback: 0\n",
      "  vocabulary_output_piece_score: 1\n",
      "  train_extremely_large_corpus: 0\n",
      "  hard_vocab_limit: 1\n",
      "  use_all_vocab: 0\n",
      "  unk_id: 0\n",
      "  bos_id: 1\n",
      "  eos_id: 2\n",
      "  pad_id: -1\n",
      "  unk_piece: <unk>\n",
      "  bos_piece: <s>\n",
      "  eos_piece: </s>\n",
      "  pad_piece: <pad>\n",
      "  unk_surface:  ⁇ \n",
      "}\n",
      "normalizer_spec {\n",
      "  name: nmt_nfkc\n",
      "  add_dummy_prefix: 1\n",
      "  remove_extra_whitespaces: 1\n",
      "  escape_whitespaces: 1\n",
      "  normalization_rule_tsv: \n",
      "}\n",
      "denormalizer_spec {}\n",
      "trainer_interface.cc(319) LOG(INFO) SentenceIterator is not specified. Using MultiFileSentenceIterator.\n",
      "trainer_interface.cc(174) LOG(INFO) Loading corpus: tokens.txt\n",
      "trainer_interface.cc(375) LOG(INFO) Loaded all 2161 sentences\n",
      "trainer_interface.cc(390) LOG(INFO) Adding meta_piece: <unk>\n",
      "trainer_interface.cc(390) LOG(INFO) Adding meta_piece: <s>\n",
      "trainer_interface.cc(390) LOG(INFO) Adding meta_piece: </s>\n",
      "trainer_interface.cc(395) LOG(INFO) Normalizing sentences...\n",
      "trainer_interface.cc(456) LOG(INFO) all chars count=101295\n",
      "trainer_interface.cc(467) LOG(INFO) Done: 99.9526% characters are covered.\n",
      "trainer_interface.cc(477) LOG(INFO) Alphabet size=68\n",
      "trainer_interface.cc(478) LOG(INFO) Final character coverage=0.999526\n",
      "trainer_interface.cc(510) LOG(INFO) Done! preprocessed 2161 sentences.\n",
      "unigram_model_trainer.cc(138) LOG(INFO) Making suffix array...\n",
      "unigram_model_trainer.cc(142) LOG(INFO) Extracting frequent sub strings...\n",
      "unigram_model_trainer.cc(193) LOG(INFO) Initialized 11093 seed sentencepieces\n",
      "trainer_interface.cc(516) LOG(INFO) Tokenizing input sentences with whitespace: 2161\n",
      "trainer_interface.cc(526) LOG(INFO) Done! 6025\n",
      "unigram_model_trainer.cc(488) LOG(INFO) Using 6025 sentences for EM training\n",
      "unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=4758 obj=13.0431 num_tokens=13019 num_tokens/piece=2.73623\n",
      "unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=4170 obj=11.4914 num_tokens=13140 num_tokens/piece=3.15108\n",
      "unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=3123 obj=11.5758 num_tokens=14016 num_tokens/piece=4.48799\n",
      "unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=3120 obj=11.4633 num_tokens=14040 num_tokens/piece=4.5\n",
      "unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=2340 obj=11.859 num_tokens=15424 num_tokens/piece=6.59145\n",
      "unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=2340 obj=11.721 num_tokens=15431 num_tokens/piece=6.59444\n",
      "unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=1755 obj=12.3254 num_tokens=17151 num_tokens/piece=9.77265\n",
      "unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=1755 obj=12.1667 num_tokens=17155 num_tokens/piece=9.77493\n",
      "unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=1316 obj=12.8195 num_tokens=19020 num_tokens/piece=14.4529\n",
      "unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=1316 obj=12.6626 num_tokens=19020 num_tokens/piece=14.4529\n",
      "unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=1100 obj=13.101 num_tokens=20116 num_tokens/piece=18.2873\n",
      "unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=1100 obj=13.0108 num_tokens=20117 num_tokens/piece=18.2882\n",
      "trainer_interface.cc(604) LOG(INFO) Saving model: cnn_dailymail.model\n",
      "trainer_interface.cc(615) LOG(INFO) Saving vocabs: cnn_dailymail.vocab\n"
     ]
    }
   ],
   "source": [
    "from text_data import CNNDatasetWrapper\n",
    "\n",
    "# Dataset wrapper configured for this notebook's short sequences and small splits.\n",
    "class Wrapper(CNNDatasetWrapper):\n",
    "    # Both lengths are 15 (presumably token counts for input and target -- confirm in text_data).\n",
    "    x_length = 15\n",
    "    target_length = 15\n",
    "    # TRAIN_SIZE, a tenth of TRAIN_SIZE, and a fixed 100\n",
    "    # (presumably train/validation/test sizes -- confirm against CNNDatasetWrapper).\n",
    "    split_lengths = [TRAIN_SIZE, math.floor(TRAIN_SIZE * 0.1), 100]\n",
    "\n",
    "\n",
    "wrapper = Wrapper(SP_VOCAB_SIZE, DEVICE)\n",
    "\n",
    "# Build every split once, then keep handles to the two used below.\n",
    "datasets = wrapper.generate_datasets(BATCH_SIZE)\n",
    "train, valid = datasets[\"train\"], datasets[\"validation\"]"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "outputs": [],
   "source": [
    "# Take the first item of the training dataset and unpack its three parts.\n",
    "first_example = train.dataset[0]\n",
    "x, y, prev_y = first_example"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "outputs": [],
   "source": [
    "# Token embedding table: one 512-dimensional vector per vocabulary entry.\n",
    "embed = nn.Embedding(\n",
    "    num_embeddings=wrapper.vocab_size,\n",
    "    embedding_dim=512,\n",
    ")"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "outputs": [
    {
     "data": {
      "text/plain": "torch.Size([15, 512])"
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Store the embedded tensor in x_embed, then show its shape as the cell's output.\n",
    "x_embed = embed(x)\n",
    "x_embed.shape"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))\n",
    "# PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))"
   ],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
